import collections
import multiprocessing as mp
import random

import dill as pickle
import torch


class SelfSupervisedBuffer:
    """Self-supervised learning buffer shared between processes.

    Stores ``(state, action, next_state)`` tuples in a
    ``multiprocessing.Manager`` list, guarded by a manager lock. When the
    buffer is full, the oldest entries are evicted first.
    """

    def __init__(self, buffer_size):
        """Create an empty buffer.

        Args:
            buffer_size: Maximum number of transitions to keep; coerced to
                ``int``.
        """
        self.buffer_size = int(buffer_size)
        self.manager = mp.Manager()
        self.buffer = self.manager.list()
        self.lock = self.manager.Lock()
        self.use_up_sample = False  # inefficient, disabled by default

    @staticmethod
    def _action_key(action):
        """Return a hashable, value-based key for ``action``.

        Objects exposing ``tolist()`` (e.g. tensors, arrays) hash by identity
        or not at all, so they are converted to plain Python values; nested
        lists become tuples. Other actions are returned unchanged.
        """
        value = action.tolist() if hasattr(action, "tolist") else action

        def freeze(item):
            if isinstance(item, list):
                return tuple(freeze(element) for element in item)
            return item

        return freeze(value)

    def add(self, states, actions, next_states):
        """Append a batch of transitions, evicting the oldest when full.

        The i-th elements of the three sequences form one transition. Only
        the first ``len(states)`` elements of ``actions`` and ``next_states``
        are used.

        Raises:
            IndexError: If ``actions`` or ``next_states`` is shorter than
                ``states``. Nothing is added in that case.
        """
        count = len(states)
        if len(actions) < count or len(next_states) < count:
            raise IndexError(
                "actions and next_states must be at least as long as states"
            )
        # zip stops at len(states) because the other two are at least as long.
        transitions = list(zip(states, actions, next_states))
        if not transitions:
            return

        with self.lock:
            # Evict exactly as many entries as one-at-a-time insertion would:
            # one per new item that arrives while the buffer is at capacity.
            overflow = len(self.buffer) + count - self.buffer_size
            evict = min(count, max(0, overflow))
            # Each proxy operation is a round-trip to the manager process, so
            # do one extend and one slice delete instead of per-item calls.
            self.buffer.extend(transitions)
            if evict:
                del self.buffer[:evict]

    def get_list(self):
        """Return a plain-list snapshot of the current buffer contents."""
        with self.lock:
            return list(self.buffer)

    def _up_sample_generator(self, buffer_list):
        """Yield transitions endlessly with every action class balanced.

        Each pass contains every original transition plus randomly repeated
        ones, so that all action classes appear as often as the most frequent
        class; the pass is shuffled before being yielded.

        Args:
            buffer_list: Sequence of ``(state, action, next_state)`` tuples.

        Raises:
            ValueError: On first iteration, if ``buffer_list`` is empty.
        """
        if not buffer_list:
            raise ValueError("Cannot up-sample an empty buffer.")

        # Group indices by action value in one pass. Keys are value-based so
        # that equal tensor actions land in the same class.
        indices_by_action = collections.defaultdict(list)
        for idx, (_, action, _) in enumerate(buffer_list):
            indices_by_action[self._action_key(action)].append(idx)
        max_count = max(len(indices) for indices in indices_by_action.values())

        while True:
            balanced_indices = []
            for indices in indices_by_action.values():
                padding = random.choices(indices, k=max_count - len(indices))
                balanced_indices.extend(indices + padding)
            random.shuffle(balanced_indices)
            for idx in balanced_indices:
                yield buffer_list[idx]

    def get_torch_dataset(self):
        """Build a ``TensorDataset`` from a snapshot of the buffer.

        States, actions and next states are each combined with
        ``torch.stack``, so the stored elements must be tensors it accepts.
        When ``use_up_sample`` is set, ``len(buffer)`` transitions are drawn
        from the balanced generator instead of using the raw snapshot.

        Raises:
            ValueError: If the buffer is empty.
        """
        # Hold the shared lock only for the copy, not while stacking tensors.
        buffer_list = self.get_list()
        if not buffer_list:
            raise ValueError("Cannot build a dataset from an empty buffer.")

        if self.use_up_sample:
            sample_generator = self._up_sample_generator(buffer_list)
            sampled = [next(sample_generator) for _ in range(len(buffer_list))]
        else:
            sampled = buffer_list

        states, actions, next_states = zip(*sampled)
        return torch.utils.data.TensorDataset(
            torch.stack(states),
            torch.stack(actions),
            torch.stack(next_states),
        )

    def clear(self):
        """Remove every transition from the buffer."""
        with self.lock:
            self.buffer[:] = []

    def size(self):
        """Return the number of transitions currently stored."""
        with self.lock:
            return len(self.buffer)

    def save(self, file_path):
        """Pickle a snapshot of the buffer to ``file_path``."""
        # Snapshot under the lock, then write without blocking other users.
        snapshot = self.get_list()
        with open(file_path, "wb") as f:
            pickle.dump(snapshot, f)

    def load(self, file_path):
        """Replace the buffer contents with the data pickled at ``file_path``.

        The loaded data is stored as-is; it is not trimmed to
        ``buffer_size``.
        """
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        with open(file_path, "rb") as f:
            data = pickle.load(f)
        with self.lock:
            self.buffer[:] = data
